#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <cassert>
#include <mpi.h>
#include "esmd_types.h"
#include "cell.h"
#include "coul_msm.h"
#include "comm.h"
#include "memory_cpp.hpp"
#include "list_cpp.hpp"
// #include "memory.h"
//#define real double
//#define esmd_malloc(x, y) malloc(x)

// Fill `range` with the 1D grid-index intervals of process `p` at level `lev`.
//
//   local : points assigned to p
//   qin   : interval read/written by the same-level direct sum
//   qout  : interval exchanged with the level below
//
// Level 0 derives everything from p's spatial slab [llen * p, llen * (p + 1));
// higher levels use the block distribution described by pcnt[] / prem[].
void get_msm_1d_range(msm_1d_range_t *range, msm_1d_part_t *part, int p, int lev) {
  if (lev == 0) {
    const real inner_lo = part->llen * p;
    const real inner_hi = part->llen * (p + 1);
    const real outer_lo = inner_lo - part->skin;
    const real outer_hi = inner_hi + part->skin;

    range->s.local_lo = floor(inner_lo * part->delinv[0]);
    range->s.local_hi = floor(inner_hi * part->delinv[0]) - 1;
    // The skin-extended slab is widened by the larger of the interpolation
    // order and the direct-sum reach.
    range->s.qout_lo = floor(outer_lo * part->delinv[0]) - max(part->order, part->ndirect);
    range->s.qout_hi = floor(outer_hi * part->delinv[0]) - 1 + max(part->order, part->ndirect);
    // On the finest level qin and qout coincide.
    range->s.qin_lo = range->s.qout_lo;
    range->s.qin_hi = range->s.qout_hi;
    return;
  }

  const int below = lev - 1;

  // Block distribution: the first prem[lev] processes own one extra point.
  range->s.local_lo = p * part->pcnt[lev] + min(p, part->prem[lev]);
  range->s.local_hi = (p + 1) * part->pcnt[lev] + min(p + 1, part->prem[lev]) - 1;
  range->s.qin_lo = range->s.local_lo - part->ndirect;
  range->s.qin_hi = range->s.local_hi + part->ndirect;

  // Local extent of p on the level below; level 0 is slab-based, not block-based.
  int below_lo, below_hi;
  if (lev == 1) {
    const real slab_lo = part->llen * p;
    const real slab_hi = part->llen * (p + 1);
    below_lo = floor(slab_lo * part->delinv[0]);
    below_hi = floor(slab_hi * part->delinv[0]) - 1;
  } else {
    below_lo = p * part->pcnt[below] + min(p, part->prem[below]);
    below_hi = (p + 1) * part->pcnt[below] + min(p + 1, part->prem[below]) - 1;
  }

  const bool same_spacing = (part->nmsm[lev] == part->nmsm[below]);
  if (same_spacing) {
    range->s.qout_lo = below_lo - part->order + 1;
    range->s.qout_hi = below_hi + part->order - 1;
  } else {
    // Grid was coarsened by 2: halve the indices (low end rounded up, high end down).
    range->s.qout_lo = (below_lo - part->order + 1 + 1) >> 1;
    range->s.qout_hi = (below_hi + part->order - 1) >> 1;
  }
}
// Build the 1D communication schedule of process `p` at level `lev`.
//
//   sched : output; receives the send/recv entry arrays, their entry counts
//           (nsend/nrecv) and the total number of 1D grid indices they cover
//           (lsend/lrecv).
//   part  : 1D partition description.
//   recv  : index into msm_1d_range_t::v of the receiver-side interval
//           (v[recv] is the low bound, v[recv + 1] the high bound).
//   send  : index into msm_1d_range_t::v of the sender-side interval.
//
// Intervals are compared under shifts by whole multiples of nmsm[lev]
// (presumably the periodic images of the grid -- confirm with the caller).
void generate_comm_schedule(msm_comm_schedule_1d_t *sched, msm_1d_part_t *part, int recv, int send, int p, int lev) {
  esmd::list<msm_comm_entry_1d_t> send_list(8, "msm/send list"), recv_list(8, "msm/recv list");
  // esmd_list_init(&send_list, 8, "msm/send list");
  // esmd_list_init(&recv_list, 8, "msm/recv list");
  // Number of processes on each side. A QOUT interval belongs to the process
  // count of the level below (np at level 0); anything else to nproc[lev].
  int nsend = (send == E_QOUT_LO) ? (lev == 0 ? part->np : part->nproc[lev - 1]) : part->nproc[lev];
  int nrecv = (recv == E_QOUT_LO) ? (lev == 0 ? part->np : part->nproc[lev - 1]) : part->nproc[lev];
  int nmsm = part->nmsm[lev];
  msm_1d_range_t p_range;
  get_msm_1d_range(&p_range, part, p, lev);
  int lsend = 0, lrecv = 0;
  // Send side: which pieces of p's send interval does each receiver q need?
  if (p < nsend) {
    int send_lo = p_range.v[send];
    int send_hi = p_range.v[send + 1];
    for (int q = 0; q < nrecv; q++) {
      msm_1d_range_t q_range;
      get_msm_1d_range(&q_range, part, q, lev);
      int recv_lo = q_range.v[recv];
      int recv_hi = q_range.v[recv + 1];
      // Bracket the shifts that can overlap: delta_st ends one step below the
      // lowest candidate and delta_ed one step above the highest; the exact
      // overlap test is repeated inside the loop.
      int delta_st = 0, delta_ed = 0;
      while (recv_hi + delta_st * nmsm >= send_lo)
        delta_st--;
      while (recv_lo + delta_ed * nmsm <= send_hi)
        delta_ed++;
      for (int d = delta_st; d <= delta_ed; d++) {
        // The unshifted overlap with ourselves needs no message.
        if (p == q && d == 0)
          continue;
        if (recv_hi + d * nmsm >= send_lo && recv_lo + d * nmsm <= send_hi) {
          msm_comm_entry_1d_t *entry = send_list.get_slot();
          // delta is stored negated on the send side; start/end are the
          // overlap expressed in p's (sender) indices.
          entry->delta = -d;
          entry->pair = q;
          entry->start = max(recv_lo + d * nmsm, send_lo);
          entry->end = min(recv_hi + d * nmsm, send_hi);
          lsend += entry->end - entry->start + 1;
        }
      }
    }
  }
  // Receive side: which pieces of p's receive interval does each sender q cover?
  if (p < nrecv) {
    int recv_lo = p_range.v[recv];
    int recv_hi = p_range.v[recv + 1];
    for (int q = 0; q < nsend; q++) {
      msm_1d_range_t q_range;
      get_msm_1d_range(&q_range, part, q, lev);
      int send_lo = q_range.v[send];
      int send_hi = q_range.v[send + 1];
      // Same bracketing as above, now shifting the sender's interval.
      int delta_st = 0, delta_ed = 0;
      while (send_hi + delta_st * nmsm >= recv_lo)
        delta_st--;
      while (send_lo + delta_ed * nmsm <= recv_hi)
        delta_ed++;
      for (int d = delta_st; d <= delta_ed; d++) {
        if (p == q && d == 0)
          continue;
        if (recv_hi >= send_lo + d * nmsm && recv_lo <= send_hi + d * nmsm) {
          msm_comm_entry_1d_t *entry = recv_list.get_slot();
          // start/end are the overlap expressed in p's (receiver) indices.
          entry->delta = d;
          entry->pair = q;
          entry->start = max(recv_lo, send_lo + d * nmsm);
          entry->end = min(recv_hi, send_hi + d * nmsm);
          lrecv += entry->end - entry->start + 1;
        }
      }
    }
  }
  // Hand the accumulated entries over to the schedule.
  sched->nrecv = recv_list.size();
  sched->lrecv = lrecv;
  sched->recv = recv_list.extract();
  sched->nsend = send_list.size();
  sched->lsend = lsend;
  sched->send = send_list.extract();
}
// Debug helper: print the send entries and then the receive entries of `sched`
// on one line each, prefixed with the process index `p`. A side with no
// entries prints nothing.
INLINE void print_sched(msm_comm_schedule_1d_t *sched, int p) {
  if (sched->nsend > 0) {
    printf("%d: send", p);
    const msm_comm_entry_1d_t *const send_end = sched->send + sched->nsend;
    for (const msm_comm_entry_1d_t *it = sched->send; it != send_end; ++it)
      printf(" [%d:%d]->%d(%d)", it->start, it->end, it->pair, it->delta);
    puts("");
  }
  if (sched->nrecv > 0) {
    printf("%d: recv", p);
    const msm_comm_entry_1d_t *const recv_end = sched->recv + sched->nrecv;
    for (const msm_comm_entry_1d_t *it = sched->recv; it != recv_end; ++it)
      printf(" [%d:%d]<-%d(%d)", it->start, it->end, it->pair, it->delta);
    puts("");
  }
}
// Initialise the 1D partition `part` for process `p` out of `np` along one axis.
//
//   order  : interpolation order
//   maxlev : number of levels to set up (array extent used here)
//   nlev   : number of levels on which this axis is still refined; level i has
//            2^(nlev - i) points for i < nlev and a single point afterwards
//   glen   : global length of the axis
//   cut    : cutoff radius
//   skin   : skin width
//
// Besides the per-level geometry this builds the three communication
// schedules of every level and records the largest entry counts and lengths
// seen over all of them.
void msm_1d_part_init(msm_1d_part_t *part, int np, int p, int order, int maxlev, int nlev, real glen, real cut, real skin) {
  part->order = order;
  part->nlev = nlev;
  part->maxlev = maxlev;
  part->glen = glen;
  part->llen = glen / np;
  part->skin = skin;
  part->cut = cut;
  part->np = np;

  for (int lev = 0; lev < maxlev; lev++) {
    part->nmsm[lev] = (lev < nlev) ? (1 << (nlev - lev)) : 1;
    part->delinv[lev] = part->nmsm[lev] / glen;
    part->del[lev] = glen / part->nmsm[lev];

    // Level 0 always uses every process; coarser levels use fewer so that
    // each keeps at least MIN_MSM_PER_PROC points (never fewer than one process).
    part->nproc[lev] = np;
    if (lev > 0)
      part->nproc[lev] = min(max(part->nmsm[lev] / MIN_MSM_PER_PROC, 1), np);

    part->pcnt[lev] = part->nmsm[lev] / part->nproc[lev];
    part->prem[lev] = part->nmsm[lev] % part->nproc[lev];
  }

  part->ndirect = 2 * part->cut * part->delinv[0];

  int peak_nrecv = 0, peak_nsend = 0, peak_lsend = 0, peak_lrecv = 0;
  for (int lev = 0; lev < maxlev; lev++) {
    msm_1d_range_t r;
    get_msm_1d_range(&r, part, p, lev);

    part->all_lo[lev] = min(r.s.qin_lo, r.s.qout_lo);
    part->all_hi[lev] = max(r.s.qin_hi, r.s.qout_hi);
    part->local_lo[lev] = r.s.local_lo;
    part->local_hi[lev] = r.s.local_hi;
    part->qin_lo[lev] = r.s.qin_lo;
    part->qin_hi[lev] = r.s.qin_hi;
    part->qout_lo[lev] = r.s.qout_lo;
    part->qout_hi[lev] = r.s.qout_hi;

    // qout: q-grid produced by the lower level / e-grid required by the lower level.
    // qin : q-grid required by this level's direct sum / e-grid it produces.
    // Restriction sums the produced q-grid (qout -> local); the direct sum
    // needs q forwarded (local -> qin); prolongation needs e forwarded
    // (local -> qout).
    generate_comm_schedule(part->rev_q_sched + lev, part, E_LOCAL_LO, E_QOUT_LO, p, lev);
    generate_comm_schedule(part->fwd_q_sched + lev, part, E_QIN_LO, E_LOCAL_LO, p, lev);
    generate_comm_schedule(part->fwd_e_sched + lev, part, E_QOUT_LO, E_LOCAL_LO, p, lev);

    msm_comm_schedule_1d_t *const built[3] = {part->rev_q_sched + lev, part->fwd_q_sched + lev, part->fwd_e_sched + lev};
    for (int s = 0; s < 3; s++) {
      peak_nrecv = max(peak_nrecv, built[s]->nrecv);
      peak_nsend = max(peak_nsend, built[s]->nsend);
      peak_lrecv = max(peak_lrecv, built[s]->lrecv);
      peak_lsend = max(peak_lsend, built[s]->lsend);
    }
  }

  part->max_nsend = peak_nsend;
  part->max_nrecv = peak_nrecv;
  part->max_lsend = peak_lsend;
  part->max_lrecv = peak_lrecv;
}

// Precompute the direct-sum tables of level `lev` and store them in `grid`:
//   g_direct[lev]     : kernel value per stencil offset
//   vs_direct[lev]    : scalar virial factor per offset
//   v_direct[lev][c]  : the six products of that factor with
//                       dx*dx, dy*dy, dz*dz, dx*dy, dx*dz, dy*dz
// The stencil spans -ndirect..ndirect along every axis and is stored x-major.
void get_gv_direct(msm_grid_t *grid, int lev) {
  int order = grid->order;
  vec<int> reach = {grid->xpart.ndirect, grid->ypart.ndirect, grid->zpart.ndirect};
  vec<int> extent;
  vecaddv(extent, reach, reach);
  vecadd(extent, extent, 1);

  real *green = esmd::allocate<real>(extent.vol(), "msm/g direct");
  grid->g_direct[lev] = green;
  real *vscalar = esmd::allocate<real>(extent.vol(), "msm/v direct");
  grid->vs_direct[lev] = vscalar;
  real *vcomp[6];
  for (int c = 0; c < 6; c++) {
    vcomp[c] = esmd::allocate<real>(extent.vol(), "msm/v direct");
    grid->v_direct[lev][c] = vcomp[c];
  }

  // Level-dependent cutoff: doubles with every level.
  real rrho = (1 << lev) * grid->rcut;
  real rrho2 = rrho * rrho;

  for (int ii = 0; ii <= 2 * reach.x; ii++) {
    int i = ii - reach.x;
    real dx = i / grid->xpart.delinv[lev];
    for (int jj = 0; jj <= 2 * reach.y; jj++) {
      int j = jj - reach.y;
      real dy = j / grid->ypart.delinv[lev];
      for (int kk = 0; kk <= 2 * reach.z; kk++) {
        int k = kk - reach.z;
        real dz = k / grid->zpart.delinv[lev];

        real rsq = dx * dx + dy * dy + dz * dz;
        real r = sqrt(rsq);
        real rho = r / rrho;
        int slot = (ii * extent.y + jj) * extent.z + kk;

        // Difference between this level's smoothed kernel and the one with
        // twice the cutoff.
        green[slot] = msm_gamma(rho, order) / rrho - msm_gamma(rho * 0.5, order) / (2 * rrho);

        // The derivative term is divided by r, so the origin is set to zero.
        real dg = 0;
        if (rsq != 0) {
          dg = -(msm_dgamma(rho, order) / rrho2 - msm_dgamma(rho * 0.5, order) / (4 * rrho2)) / r;
        }
        vscalar[slot] = dg;
        vcomp[0][slot] = dg * dx * dx;
        vcomp[1][slot] = dg * dy * dy;
        vcomp[2][slot] = dg * dz * dz;
        vcomp[3][slot] = dg * dx * dy;
        vcomp[4][slot] = dg * dx * dz;
        vcomp[5][slot] = dg * dy * dz;
      }
    }
  }
}

// Initialise the 3D MSM grid hierarchy `grid` for the process layout in `mpp`.
//
//   nlev       : number of refined levels per axis; the hierarchy depth is the
//                largest of the three (never below 0)
//   order      : interpolation order
//   rcut       : cutoff radius
//   skin       : per-axis skin width
//   coul_const : Coulomb constant stored on the grid
//
// Sets up the three 1D partitions, splits `mpp->comm` into one communicator
// per axis (processes sharing the other two coordinates, ranked by their
// coordinate along the axis), allocates the per-level q/e grids and direct-sum
// tables, and sizes the communication buffers and request arrays.
void msm_grid_init(msm_grid_t *grid, mpp_t *mpp, vec<int> * nlev, int order, real rcut, vec<real> * skin, real coul_const) {
  grid->mpp = mpp;
  boxcpy(grid->lbox, mpp->lbox);
  boxcpy(grid->gbox, mpp->gbox);

  box<real> *gbox = &mpp->gbox;
  grid->order = order;
  grid->rcut = rcut;
  // Stored once up front so it is set even when there are no levels at all.
  grid->coul_const = coul_const;
  veccpy(grid->skin, *skin);
  grid->maxlev = 0;
  if (nlev->x > grid->maxlev)
    grid->maxlev = nlev->x;
  if (nlev->y > grid->maxlev)
    grid->maxlev = nlev->y;
  if (nlev->z > grid->maxlev)
    grid->maxlev = nlev->z;
  msm_1d_part_init(&grid->xpart, mpp->dim.x, mpp->loc.x, order, grid->maxlev, nlev->x, gbox->hi.x - gbox->lo.x, rcut, skin->x);
  msm_1d_part_init(&grid->ypart, mpp->dim.y, mpp->loc.y, order, grid->maxlev, nlev->y, gbox->hi.y - gbox->lo.y, rcut, skin->y);
  msm_1d_part_init(&grid->zpart, mpp->dim.z, mpp->loc.z, order, grid->maxlev, nlev->z, gbox->hi.z - gbox->lo.z, rcut, skin->z);
  vecset1(grid->max_nall, 0);
  // One communicator per axis: the colour identifies the line of processes
  // (the two fixed coordinates), the key orders them along the axis.
  MPI_Comm_split(mpp->comm, mpp->loc.y * mpp->dim.z + mpp->loc.z, mpp->loc.x, &grid->xcomm);
  MPI_Comm_split(mpp->comm, mpp->loc.x * mpp->dim.z + mpp->loc.z, mpp->loc.y, &grid->ycomm);
  MPI_Comm_split(mpp->comm, mpp->loc.x * mpp->dim.y + mpp->loc.y, mpp->loc.z, &grid->zcomm);

  for (int i = 0; i < grid->maxlev; i++) {
    // Extent of everything this process stores at level i (local + halos).
    grid->nall[i].x = grid->xpart.all_hi[i] - grid->xpart.all_lo[i] + 1;
    grid->nall[i].y = grid->ypart.all_hi[i] - grid->ypart.all_lo[i] + 1;
    grid->nall[i].z = grid->zpart.all_hi[i] - grid->zpart.all_lo[i] + 1;

    grid->nlocal[i].x = grid->xpart.local_hi[i] - grid->xpart.local_lo[i] + 1;
    grid->nlocal[i].y = grid->ypart.local_hi[i] - grid->ypart.local_lo[i] + 1;
    grid->nlocal[i].z = grid->zpart.local_hi[i] - grid->zpart.local_lo[i] + 1;

    grid->low[i].x = grid->xpart.all_lo[i];
    grid->low[i].y = grid->ypart.all_lo[i];
    grid->low[i].z = grid->zpart.all_lo[i];

    grid->high[i].x = grid->xpart.all_hi[i];
    grid->high[i].y = grid->ypart.all_hi[i];
    grid->high[i].z = grid->zpart.all_hi[i];
    grid->q[i] = esmd::allocate<real>(grid->nall[i].vol(), "msm/qgrid");
    grid->e[i] = esmd::allocate<real>(grid->nall[i].vol(), "msm/egrid");
    vecmaxv(grid->max_nall, grid->max_nall, grid->nall[i]);
    get_gv_direct(grid, i);
    // A process takes part in level i only if it is within the reduced
    // process count of that level along every axis.
    if (mpp->loc.x < grid->xpart.nproc[i] && mpp->loc.y < grid->ypart.nproc[i] && mpp->loc.z < grid->zpart.nproc[i]) {
      grid->active[i] = 1;
    } else {
      grid->active[i] = 0;
    }
  }

  // Worst-case message volume: the longest 1D schedule of an axis times the
  // largest cross-section perpendicular to it.
  int max_lsend = max3(grid->xpart.max_lsend * grid->max_nall.y * grid->max_nall.z,
                       grid->ypart.max_lsend * grid->max_nall.x * grid->max_nall.z,
                       grid->zpart.max_lsend * grid->max_nall.x * grid->max_nall.y);
  int max_lrecv = max3(grid->xpart.max_lrecv * grid->max_nall.y * grid->max_nall.z,
                       grid->ypart.max_lrecv * grid->max_nall.x * grid->max_nall.z,
                       grid->zpart.max_lrecv * grid->max_nall.x * grid->max_nall.y);
  int max_nsend = max3(grid->xpart.max_nsend, grid->ypart.max_nsend, grid->zpart.max_nsend);
  int max_nrecv = max3(grid->xpart.max_nrecv, grid->ypart.max_nrecv, grid->zpart.max_nrecv);
  // NOTE(review): the factor 20 looks like a safety margin; its origin is not
  // visible here.
  grid->send_buf = esmd::allocate<real>(max_lsend * 20, "msm/send buf");
  grid->recv_buf = esmd::allocate<real>(max_lrecv * 20, "msm/recv buf");
  grid->rreq = esmd::allocate<MPI_Request>(max_nrecv * 20, "msm/recv req");
  grid->rstat = esmd::allocate<MPI_Status>(max_nrecv * 20, "msm/recv stat");
  grid->sreq = esmd::allocate<MPI_Request>(max_nsend * 20, "msm/send req");
  grid->sstat = esmd::allocate<MPI_Status>(max_nsend * 20, "msm/send stat");
}
// Sign of x as an int: 1 if x > 0, -1 if x < 0, 0 otherwise.
// The argument is parenthesised so that expressions with low-precedence
// operators (e.g. `a & b`, `c ? x : y`) expand correctly; note that it is
// still evaluated twice.
#define sign(x) (((x) > 0) - ((x) < 0))
// Evaluate the piecewise-polynomial basis function independently for the
// three components of `d` and store the results in `ret`.
// The polynomial piece is chosen from PHICOEF[order >> 1] by floor(|d|),
// clamped to (order - 1) >> 1, and evaluated in |d| by Horner's rule over
// `order` coefficients.
void compute_phi1(vec<real> * ret, vec<real> * d, int order) {
  vec<real> mag;
  vecset3(mag, fabs(d->x), fabs(d->y), fabs(d->z));
  vec<int> piece;
  vecset3(piece, floor(mag.x), floor(mag.y), floor(mag.z));

  const real *cx = PHICOEF[order >> 1][min(piece.x, (order - 1) >> 1)];
  const real *cy = PHICOEF[order >> 1][min(piece.y, (order - 1) >> 1)];
  const real *cz = PHICOEF[order >> 1][min(piece.z, (order - 1) >> 1)];

  real px = cx[0];
  real py = cy[0];
  real pz = cz[0];
  for (int k = 1; k < order; k++) {
    px = px * mag.x + cx[k];
    py = py * mag.y + cy[k];
    pz = pz * mag.z + cz[k];
  }
  ret->x = px;
  ret->y = py;
  ret->z = pz;
}

// Evaluate the derivative of the basis function independently for the three
// components of `d` and store the results in `ret`.
// The piece is chosen from DPHICOEF[order >> 1] exactly as in compute_phi1,
// the polynomial (order - 1 coefficients) is evaluated in |d|, and the result
// takes the sign of the corresponding component of `d`.
void compute_dphi1(vec<real> * ret, vec<real> * d, int order) {
  vec<real> mag;
  vecset3(mag, fabs(d->x), fabs(d->y), fabs(d->z));
  vec<int> piece;
  vecset3(piece, floor(mag.x), floor(mag.y), floor(mag.z));

  const real *cx = DPHICOEF[order >> 1][min(piece.x, (order - 1) >> 1)];
  const real *cy = DPHICOEF[order >> 1][min(piece.y, (order - 1) >> 1)];
  const real *cz = DPHICOEF[order >> 1][min(piece.z, (order - 1) >> 1)];

  real px = cx[0];
  real py = cy[0];
  real pz = cz[0];
  for (int k = 1; k < order - 1; k++) {
    px = px * mag.x + cx[k];
    py = py * mag.y + cy[k];
    pz = pz * mag.z + cz[k];
  }
  // The polynomial was evaluated on |d|; restore the orientation.
  ret->x = px * sign(d->x);
  ret->y = py * sign(d->y);
  ret->z = pz * sign(d->z);
}
// Fill ret[0 .. order-1] with the basis-function values at `d` shifted by the
// integer offsets -((order - 1) >> 1), ..., order - 1 - ((order - 1) >> 1).
void compute_phis(vec<real> * ret, vec<real> * d, int order) {
  const int center = (order - 1) >> 1;
  for (int node = 0; node < order; ++node) {
    vec<real> shifted;
    vecadd(shifted, *d, node - center);
    compute_phi1(&ret[node], &shifted, order);
  }
}
// Fill ret[0 .. order-1] with the basis-function derivatives at `d` shifted by
// the integer offsets -((order - 1) >> 1), ..., order - 1 - ((order - 1) >> 1).
void compute_dphis(vec<real> * ret, vec<real> * d, int order) {
  const int center = (order - 1) >> 1;
  for (int node = 0; node < order; ++node) {
    vec<real> shifted;
    vecadd(shifted, *d, node - center);
    compute_dphi1(&ret[node], &shifted, order);
  }
}
// Upper bound used to size the on-stack stencil arrays (offsets and weights)
// in restriction() and prolong().
#define MAX_ORDER 12

// Copy the sub-brick [box->lo, box->hi) of grid `q` into the contiguous buffer
// `buf` (x-major, z fastest). `low` is the global index of q's first stored
// point and `nall` its stored extent. Returns the brick volume as computed by
// vecvol.
size_t pack_grid_brick(real *buf, real *q, vec<int> * low, vec<int> * nall, box<int> * box) {
  vec<int> npack;
  vecsubv(npack, box->hi, box->lo);

  real *out = buf;  // the buffer is filled strictly in loop order
  for (int ii = 0; ii < npack.x; ii++) {
    const int gx = box->lo.x + ii - low->x;
    for (int jj = 0; jj < npack.y; jj++) {
      const int gy = box->lo.y + jj - low->y;
      for (int kk = 0; kk < npack.z; kk++) {
        const int gz = box->lo.z + kk - low->z;
        *out++ = q[(gx * nall->y + gy) * nall->z + gz];
      }
    }
  }
  return vecvol(npack);
}

// Add the contiguous buffer `buf` (x-major, z fastest) onto the sub-brick
// [box->lo, box->hi) of grid `q`. `low` is the global index of q's first
// stored point and `nall` its stored extent. Returns the brick volume as
// computed by vecvol.
size_t unpack_add_grid_brick(real *buf, real *q, vec<int> * low, vec<int> * nall, box<int> * box) {
  vec<int> npack;
  vecsubv(npack, box->hi, box->lo);

  const real *in = buf;  // the buffer is consumed strictly in loop order
  for (int ii = 0; ii < npack.x; ii++) {
    const int gx = box->lo.x + ii - low->x;
    for (int jj = 0; jj < npack.y; jj++) {
      const int gy = box->lo.y + jj - low->y;
      for (int kk = 0; kk < npack.z; kk++) {
        const int gz = box->lo.z + kk - low->z;
        q[(gx * nall->y + gy) * nall->z + gz] += *in++;
      }
    }
  }
  return vecvol(npack);
}

// Overwrite the sub-brick [box->lo, box->hi) of grid `q` with the contiguous
// buffer `buf` (x-major, z fastest). `low` is the global index of q's first
// stored point and `nall` its stored extent. Returns the brick volume as
// computed by vecvol.
size_t unpack_copy_grid_brick(real *buf, real *q, vec<int> * low, vec<int> * nall, box<int> * box) {
  vec<int> npack;
  vecsubv(npack, box->hi, box->lo);

  const real *in = buf;  // the buffer is consumed strictly in loop order
  for (int ii = 0; ii < npack.x; ii++) {
    const int gx = box->lo.x + ii - low->x;
    for (int jj = 0; jj < npack.y; jj++) {
      const int gy = box->lo.y + jj - low->y;
      for (int kk = 0; kk < npack.z; kk++) {
        const int gz = box->lo.z + kk - low->z;
        q[(gx * nall->y + gy) * nall->z + gz] = *in++;
      }
    }
  }
  return vecvol(npack);
}
// Base MPI tag for grid exchanges; the tag of a message is
// basetag + axis * 128 + entry->delta.
static const int basetag = 0x2000;
// Post one non-blocking receive per receive entry of `sched` on `comm`.
// Each message holds (end - start + 1) * area values; the messages are laid
// out back to back in grid->recv_buf and request i is stored in grid->rreq[i].
void make_recv_requests(msm_comm_schedule_1d_t *sched, msm_grid_t *grid, int area, int axis, MPI_Comm comm) {
  const int tag_base = basetag + axis * 128;
  int filled = 0;  // values of recv_buf already reserved for earlier entries
  for (int i = 0; i < sched->nrecv; i++) {
    msm_comm_entry_1d_t *entry = &sched->recv[i];
    const int count = (entry->end - entry->start + 1) * area;
    MPI_Irecv(grid->recv_buf + filled, count, mpi_real, entry->pair, tag_base + entry->delta, comm, grid->rreq + i);
    filled += count;
  }
}
//#define DEBUG_COMM
// Axis letters indexed by axis number (0, 1, 2); only read by the DEBUG_COMM
// trace output.
static const char *axis_name = "xyz";
// Pack and send, without blocking, one brick of grid `g` per send entry of
// `sched`. `box0` fixes the brick along the two other axes; along `axis` the
// brick covers [entry->start, entry->end]. Bricks are packed back to back into
// grid->send_buf and request i is stored in grid->sreq[i].
void pack_send_grid(msm_comm_schedule_1d_t *sched, msm_grid_t *grid, real *g, box<int> * box0, int axis, int lev, MPI_Comm comm) {
#ifdef DEBUG_COMM
  int rank;
  MPI_Comm_rank(comm, &rank);
#endif
  box<int> brick;
  boxcpy(brick, *box0);
  // NOTE(review): addresses the box as six consecutive ints (lo.x, lo.y, lo.z,
  // hi.x, hi.y, hi.z) -- relies on the layout of box<int>; confirm.
  int *axis_lo = ((int *)&brick) + axis;
  int *axis_hi = ((int *)&brick) + axis + 3;

  int used = 0;  // values of send_buf occupied by earlier entries
  for (int i = 0; i < sched->nsend; i++) {
    msm_comm_entry_1d_t *entry = &sched->send[i];
    const int tag = basetag + axis * 128 + entry->delta;
    *axis_lo = entry->start;
    *axis_hi = entry->end + 1;  // the box upper bound is exclusive
#ifdef DEBUG_COMM
    printf("%c: %d sending (%d, %d, %d) -- (%d, %d, %d) --> %d, tag=%d\n", axis_name[axis], rank,
           brick.lo.x, brick.lo.y, brick.lo.z,
           brick.hi.x, brick.hi.y, brick.hi.z,
           entry->pair, tag);
#endif
    int count = pack_grid_brick(grid->send_buf + used, g, grid->low + lev, grid->nall + lev, &brick);
    MPI_Isend(grid->send_buf + used, count, mpi_real, entry->pair, tag, comm, grid->sreq + i);
    used += count;
  }
}
// Accumulate the received bricks in grid->recv_buf onto grid `g`, one per
// receive entry of `sched` and in entry order. `box0` fixes the brick along
// the two other axes; along `axis` the brick covers [entry->start, entry->end].
void recv_add_grid(msm_comm_schedule_1d_t *sched, msm_grid_t *grid, real *g, box<int> * box0, int axis, int lev, MPI_Comm comm) {
#ifdef DEBUG_COMM
  int rank;
  MPI_Comm_rank(comm, &rank);
#endif
  box<int> brick;
  boxcpy(brick, *box0);
  // NOTE(review): addresses the box as six consecutive ints (lo.x, lo.y, lo.z,
  // hi.x, hi.y, hi.z) -- relies on the layout of box<int>; confirm.
  int *axis_lo = ((int *)&brick) + axis;
  int *axis_hi = ((int *)&brick) + axis + 3;

  int consumed = 0;  // values of recv_buf already unpacked
  for (int i = 0; i < sched->nrecv; i++) {
    msm_comm_entry_1d_t *entry = &sched->recv[i];
    *axis_lo = entry->start;
    *axis_hi = entry->end + 1;  // the box upper bound is exclusive
#ifdef DEBUG_COMM
    printf("%c: %d receiving+ (%d, %d, %d) -- (%d, %d, %d) <-- %d, tag=%d\n", axis_name[axis], rank,
           brick.lo.x, brick.lo.y, brick.lo.z,
           brick.hi.x, brick.hi.y, brick.hi.z,
           entry->pair, basetag + axis * 128 + entry->delta);
#endif
    int count = unpack_add_grid_brick(grid->recv_buf + consumed, g, grid->low + lev, grid->nall + lev, &brick);
    consumed += count;
  }
}
// Copy the received bricks in grid->recv_buf into grid `g`, one per receive
// entry of `sched` and in entry order, overwriting the target points. `box0`
// fixes the brick along the two other axes; along `axis` the brick covers
// [entry->start, entry->end].
void recv_copy_grid(msm_comm_schedule_1d_t *sched, msm_grid_t *grid, real *g, box<int> * box0, int axis, int lev, MPI_Comm comm) {
#ifdef DEBUG_COMM
  int rank;
  MPI_Comm_rank(comm, &rank);
#endif
  box<int> brick;
  boxcpy(brick, *box0);
  // NOTE(review): addresses the box as six consecutive ints (lo.x, lo.y, lo.z,
  // hi.x, hi.y, hi.z) -- relies on the layout of box<int>; confirm.
  int *axis_lo = ((int *)&brick) + axis;
  int *axis_hi = ((int *)&brick) + axis + 3;

  int consumed = 0;  // values of recv_buf already unpacked
  for (int i = 0; i < sched->nrecv; i++) {
    msm_comm_entry_1d_t *entry = &sched->recv[i];
    *axis_lo = entry->start;
    *axis_hi = entry->end + 1;  // the box upper bound is exclusive
#ifdef DEBUG_COMM
    printf("%c: %d receiving (%d, %d, %d) -- (%d, %d, %d) <-- %d, tag=%d\n", axis_name[axis], rank,
           brick.lo.x, brick.lo.y, brick.lo.z,
           brick.hi.x, brick.hi.y, brick.hi.z,
           entry->pair, basetag + axis * 128 + entry->delta);
#endif
    int count = unpack_copy_grid_brick(grid->recv_buf + consumed, g, grid->low + lev, grid->nall + lev, &brick);
    consumed += count;
  }
}
// Reverse communication of grid `g` at level `lev`: contributions held in the
// qout region are sent to, and summed on, the processes that own the points.
// The exchange runs axis by axis (x, then y, then z); once an axis is done
// only the local extent along it is exchanged for the following axes.
//
// Every phase waits for exactly the requests it posted for this level's
// schedule, so the shared buffers and request arrays can be reused by the next
// phase. `mpp` is not used.
void reverse_q(msm_grid_t *grid, mpp_t *mpp, real *g, int lev) {
#ifdef DEBUG_COMM
  puts("reverse_q");
  MPI_Barrier(MPI_COMM_WORLD);
#endif
  msm_comm_schedule_1d_t *xsched = grid->xpart.rev_q_sched + lev;
  msm_comm_schedule_1d_t *ysched = grid->ypart.rev_q_sched + lev;
  msm_comm_schedule_1d_t *zsched = grid->zpart.rev_q_sched + lev;

  //x first: full stored extent in y and z
  box<int> comm_box_x = {{0, grid->low[lev].y, grid->low[lev].z}, {0, grid->high[lev].y + 1, grid->high[lev].z + 1}};
  make_recv_requests(xsched, grid, grid->nall[lev].y * grid->nall[lev].z, 0, grid->xcomm);
  pack_send_grid(xsched, grid, g, &comm_box_x, 0, lev, grid->xcomm);
  MPI_Waitall(xsched->nrecv, grid->rreq, grid->rstat);
  recv_add_grid(xsched, grid, g, &comm_box_x, 0, lev, grid->xcomm);
  MPI_Waitall(xsched->nsend, grid->sreq, grid->sstat);

  //y then: local extent in x, full stored extent in z
  box<int> comm_box_y = {{grid->xpart.local_lo[lev], 0, grid->low[lev].z}, {grid->xpart.local_hi[lev] + 1, 0, grid->high[lev].z + 1}};
  make_recv_requests(ysched, grid, grid->nlocal[lev].x * grid->nall[lev].z, 1, grid->ycomm);
  pack_send_grid(ysched, grid, g, &comm_box_y, 1, lev, grid->ycomm);
  MPI_Waitall(ysched->nrecv, grid->rreq, grid->rstat);
  recv_add_grid(ysched, grid, g, &comm_box_y, 1, lev, grid->ycomm);
  // Wait on this level's send count, not on the level-0 schedule's.
  MPI_Waitall(ysched->nsend, grid->sreq, grid->sstat);

  //z last: local extent in x and y
  box<int> comm_box_z = {{grid->xpart.local_lo[lev], grid->ypart.local_lo[lev], 0}, {grid->xpart.local_hi[lev] + 1, grid->ypart.local_hi[lev] + 1, 0}};
  make_recv_requests(zsched, grid, grid->nlocal[lev].x * grid->nlocal[lev].y, 2, grid->zcomm);
  pack_send_grid(zsched, grid, g, &comm_box_z, 2, lev, grid->zcomm);
  MPI_Waitall(zsched->nrecv, grid->rreq, grid->rstat);
  recv_add_grid(zsched, grid, g, &comm_box_z, 2, lev, grid->zcomm);
  // Wait on this level's send count, not on the level-0 schedule's.
  MPI_Waitall(zsched->nsend, grid->sreq, grid->sstat);
#ifdef DEBUG_COMM
  MPI_Barrier(MPI_COMM_WORLD);
#endif
}

// Forward communication of the charge grid `g` at level `lev`: locally owned
// values are copied into the qin halo of the processes that need them.
// The exchange runs axis by axis (z, then y, then x); after an axis is done
// its full stored extent takes part in the following axes. `mpp` is not used.
void forward_q(msm_grid_t *grid, mpp_t *mpp, real *g, int lev) {
#ifdef DEBUG_COMM
  puts("forward_q");
  MPI_Barrier(MPI_COMM_WORLD);
#endif
  msm_comm_schedule_1d_t *const zsched = grid->zpart.fwd_q_sched + lev;
  msm_comm_schedule_1d_t *const ysched = grid->ypart.fwd_q_sched + lev;
  msm_comm_schedule_1d_t *const xsched = grid->xpart.fwd_q_sched + lev;

  // z first: local extent in x and y
  {
    box<int> slab = {{grid->xpart.local_lo[lev], grid->ypart.local_lo[lev], 0},
                     {grid->xpart.local_hi[lev] + 1, grid->ypart.local_hi[lev] + 1, 0}};
    make_recv_requests(zsched, grid, grid->nlocal[lev].x * grid->nlocal[lev].y, 2, grid->zcomm);
    pack_send_grid(zsched, grid, g, &slab, 2, lev, grid->zcomm);
    MPI_Waitall(zsched->nrecv, grid->rreq, grid->rstat);
    recv_copy_grid(zsched, grid, g, &slab, 2, lev, grid->zcomm);
    MPI_Waitall(zsched->nsend, grid->sreq, grid->sstat);
  }

  // then y: local extent in x, full stored extent in z
  {
    box<int> slab = {{grid->xpart.local_lo[lev], 0, grid->low[lev].z},
                     {grid->xpart.local_hi[lev] + 1, 0, grid->high[lev].z + 1}};
    make_recv_requests(ysched, grid, grid->nlocal[lev].x * grid->nall[lev].z, 1, grid->ycomm);
    pack_send_grid(ysched, grid, g, &slab, 1, lev, grid->ycomm);
    MPI_Waitall(ysched->nrecv, grid->rreq, grid->rstat);
    recv_copy_grid(ysched, grid, g, &slab, 1, lev, grid->ycomm);
    MPI_Waitall(ysched->nsend, grid->sreq, grid->sstat);
  }

  // x last: full stored extent in y and z
  {
    box<int> slab = {{0, grid->low[lev].y, grid->low[lev].z}, {0, grid->high[lev].y + 1, grid->high[lev].z + 1}};
    make_recv_requests(xsched, grid, grid->nall[lev].y * grid->nall[lev].z, 0, grid->xcomm);
    pack_send_grid(xsched, grid, g, &slab, 0, lev, grid->xcomm);
    MPI_Waitall(xsched->nrecv, grid->rreq, grid->rstat);
    recv_copy_grid(xsched, grid, g, &slab, 0, lev, grid->xcomm);
    MPI_Waitall(xsched->nsend, grid->sreq, grid->sstat);
  }
#ifdef DEBUG_COMM
  MPI_Barrier(MPI_COMM_WORLD);
#endif
}

// Forward communication of the potential grid `g` at level `lev`: locally
// owned values are copied into the qout region of the processes that need
// them. The exchange runs axis by axis (z, then y, then x); after an axis is
// done its full stored extent takes part in the following axes. `mpp` is not
// used.
void forward_e(msm_grid_t *grid, mpp_t *mpp, real *g, int lev) {
#ifdef DEBUG_COMM
  puts("forward_e");
  MPI_Barrier(MPI_COMM_WORLD);
#endif
  msm_comm_schedule_1d_t *const zsched = grid->zpart.fwd_e_sched + lev;
  msm_comm_schedule_1d_t *const ysched = grid->ypart.fwd_e_sched + lev;
  msm_comm_schedule_1d_t *const xsched = grid->xpart.fwd_e_sched + lev;

  // z first: local extent in x and y
  {
    box<int> slab = {{grid->xpart.local_lo[lev], grid->ypart.local_lo[lev], 0},
                     {grid->xpart.local_hi[lev] + 1, grid->ypart.local_hi[lev] + 1, 0}};
    make_recv_requests(zsched, grid, grid->nlocal[lev].x * grid->nlocal[lev].y, 2, grid->zcomm);
    pack_send_grid(zsched, grid, g, &slab, 2, lev, grid->zcomm);
    MPI_Waitall(zsched->nrecv, grid->rreq, grid->rstat);
    recv_copy_grid(zsched, grid, g, &slab, 2, lev, grid->zcomm);
    MPI_Waitall(zsched->nsend, grid->sreq, grid->sstat);
  }

  // then y: local extent in x, full stored extent in z
  {
    box<int> slab = {{grid->xpart.local_lo[lev], 0, grid->low[lev].z},
                     {grid->xpart.local_hi[lev] + 1, 0, grid->high[lev].z + 1}};
    make_recv_requests(ysched, grid, grid->nlocal[lev].x * grid->nall[lev].z, 1, grid->ycomm);
    pack_send_grid(ysched, grid, g, &slab, 1, lev, grid->ycomm);
    MPI_Waitall(ysched->nrecv, grid->rreq, grid->rstat);
    recv_copy_grid(ysched, grid, g, &slab, 1, lev, grid->ycomm);
    MPI_Waitall(ysched->nsend, grid->sreq, grid->sstat);
  }

  // x last: full stored extent in y and z
  {
    box<int> slab = {{0, grid->low[lev].y, grid->low[lev].z}, {0, grid->high[lev].y + 1, grid->high[lev].z + 1}};
    make_recv_requests(xsched, grid, grid->nall[lev].y * grid->nall[lev].z, 0, grid->xcomm);
    pack_send_grid(xsched, grid, g, &slab, 0, lev, grid->xcomm);
    MPI_Waitall(xsched->nrecv, grid->rreq, grid->rstat);
    recv_copy_grid(xsched, grid, g, &slab, 0, lev, grid->xcomm);
    MPI_Waitall(xsched->nsend, grid->sreq, grid->sstat);
  }
#ifdef DEBUG_COMM
  MPI_Barrier(MPI_COMM_WORLD);
#endif
}

// Restrict the charge grid of level `lev` onto level `lev + 1`.
//
// For every level-(lev+1) point in this process's qout range the value is the
// weighted sum of the locally owned level-`lev` charges around it; the result
// overwrites grid->q[lev + 1] at that point. Fine points outside the local
// range are skipped here; their contributions live on other processes and are
// presumably combined afterwards by reverse_q (confirm with the caller).
//
// Per axis the coarsening factor is 2 when the point count changes between
// the levels and 1 otherwise. The 1D stencil consists of offset 0 and every
// odd offset in [-order, order], weighted by the basis function evaluated at
// offset / factor.
void restriction(msm_grid_t *grid, int lev) {
  int order = grid->order;
  // The stencil holds at most MAX_ORDER + 1 offsets for order <= MAX_ORDER.
  assert(order <= MAX_ORDER);
  vec<int> qout_lo = {grid->xpart.qout_lo[lev + 1], grid->ypart.qout_lo[lev + 1], grid->zpart.qout_lo[lev + 1]};
  vec<int> qout_hi = {grid->xpart.qout_hi[lev + 1], grid->ypart.qout_hi[lev + 1], grid->zpart.qout_hi[lev + 1]};
  vec<int> factor;
  factor.x = grid->xpart.nmsm[lev] == grid->xpart.nmsm[lev + 1] ? 1 : 2;
  factor.y = grid->ypart.nmsm[lev] == grid->ypart.nmsm[lev + 1] ? 1 : 2;
  factor.z = grid->zpart.nmsm[lev] == grid->zpart.nmsm[lev + 1] ? 1 : 2;

  // Offsets and weights are filled in lockstep, so both arrays need the same
  // capacity.
  int nus[MAX_ORDER + 1], nnu = 0;
  vec<real> phi1d[MAX_ORDER + 1];
  for (int i = -order; i <= order; i++) {
    if ((i & 1) == 0 && i != 0)
      continue;
    nus[nnu] = i;
    vec<real> d = {i * 1.0 / factor.x, i * 1.0 / factor.y, i * 1.0 / factor.z};
    compute_phi1(phi1d + nnu, &d, order);
    nnu++;
  }
  real *qupper = grid->q[lev + 1];
  real *qlower = grid->q[lev];

  for (int i = qout_lo.x; i <= qout_hi.x; i++) {
    int ii = i - grid->low[lev + 1].x;           // storage index on the coarse grid
    int cx = i * factor.x - grid->low[lev].x;    // storage index of the matching fine point
    for (int j = qout_lo.y; j <= qout_hi.y; j++) {
      int jj = j - grid->low[lev + 1].y;
      int cy = j * factor.y - grid->low[lev].y;
      for (int k = qout_lo.z; k <= qout_hi.z; k++) {
        int kk = k - grid->low[lev + 1].z;
        int cz = k * factor.z - grid->low[lev].z;
        real qsum = 0;
        for (int dx = 0; dx < nnu; dx++) {
          int x = cx + nus[dx];
          // Only locally owned fine points contribute.
          if (x + grid->low[lev].x < grid->xpart.local_lo[lev] || x + grid->low[lev].x > grid->xpart.local_hi[lev])
            continue;
          real phix = phi1d[dx].x;
          for (int dy = 0; dy < nnu; dy++) {
            int y = cy + nus[dy];
            if (y + grid->low[lev].y < grid->ypart.local_lo[lev] || y + grid->low[lev].y > grid->ypart.local_hi[lev])
              continue;
            real phixy = phix * phi1d[dy].y;
            for (int dz = 0; dz < nnu; dz++) {
              int z = cz + nus[dz];
              if (z + grid->low[lev].z < grid->zpart.local_lo[lev] || z + grid->low[lev].z > grid->zpart.local_hi[lev])
                continue;
              real phixyz = phixy * phi1d[dz].z;
              qsum += qlower[(x * grid->nall[lev].y + y) * grid->nall[lev].z + z] * phixyz;
            }
          }
        }
        qupper[(ii * grid->nall[lev + 1].y + jj) * grid->nall[lev + 1].z + kk] = qsum;
      }
    }
  }
}

// Direct (same-level) sum at level `lev`.
//
// For every locally owned grid point the potential is the sum of the charges
// within +-ndirect points along each axis, weighted by the precomputed table
// g_direct[lev]; it is written to grid->e[lev]. The products of each point's
// charge with its potential and with its six virial sums are added to
// stat->ecoul and stat->virial[0..5]. Processes that are not active on this
// level return immediately.
void direct(msm_grid_t *grid, int lev, mdstat_t *stat) {
  if (!grid->active[lev])
    return;

  vec<int> ndirect = {grid->xpart.ndirect, grid->ypart.ndirect, grid->zpart.ndirect};
  vec<int> dlen;
  vecscaleadd(dlen, ndirect, 2, 1);
  vec<int> first = {grid->xpart.local_lo[lev], grid->ypart.local_lo[lev], grid->zpart.local_lo[lev]};
  vec<int> last = {grid->xpart.local_hi[lev], grid->ypart.local_hi[lev], grid->zpart.local_hi[lev]};
  vec<real> del = {grid->xpart.del[lev], grid->ypart.del[lev], grid->zpart.del[lev]};
  vec<int> nall = vecarr(grid->nall[lev]);

  real *kernel = grid->g_direct[lev];
  real *vkernel = grid->vs_direct[lev];
  real *q = grid->q[lev];
  real *e = grid->e[lev];

  real etot = 0;
  real vtot[6] = {0, 0, 0, 0, 0, 0};

  for (int i = first.x; i <= last.x; i++) {
    int ii = i - grid->low[lev].x;
    for (int j = first.y; j <= last.y; j++) {
      int jj = j - grid->low[lev].y;
      for (int k = first.z; k <= last.z; k++) {
        int kk = k - grid->low[lev].z;
        const int center = (ii * nall.y + jj) * nall.z + kk;

        real esum = 0;
        real vsum[6] = {0, 0, 0, 0, 0, 0};
        for (int dx = -ndirect.x; dx <= ndirect.x; dx++) {
          int x = ii + dx;
          int ddx = dx + ndirect.x;
          real fdx = dx * del.x;
          for (int dy = -ndirect.y; dy <= ndirect.y; dy++) {
            int y = jj + dy;
            int ddy = dy + ndirect.y;
            real fdy = dy * del.y;
            for (int dz = -ndirect.z; dz <= ndirect.z; dz++) {
              int z = kk + dz;
              int ddz = dz + ndirect.z;
              real fdz = dz * del.z;

              const int src = (x * nall.y + y) * nall.z + z;              // neighbour charge
              const int tap = (ddx * dlen.y + ddy) * dlen.z + ddz;        // table entry
              esum += q[src] * kernel[tap];
              real vs = vkernel[tap];
              vsum[0] += q[src] * vs * fdx * fdx;
              vsum[1] += q[src] * vs * fdy * fdy;
              vsum[2] += q[src] * vs * fdz * fdz;
              vsum[3] += q[src] * vs * fdx * fdy;
              vsum[4] += q[src] * vs * fdx * fdz;
              vsum[5] += q[src] * vs * fdy * fdz;
            }
          }
        }

        e[center] = esum;
        etot += q[center] * esum;
        for (int c = 0; c < 6; c++)
          vtot[c] += q[center] * vsum[c];
      }
    }
  }

  stat->ecoul += etot;
  for (int c = 0; c < 6; c++)
    stat->virial[c] += vtot[c];
}
/*
 * Prolongation step: accumulate the potential stored on level lev + 1
 * (grid->e[lev + 1]) into the potential of level lev (grid->e[lev]).
 *
 * Every upper-level point in the [qout_lo, qout_hi] range of level lev + 1 is
 * scattered onto lower-level points at offset 0 and at the odd offsets in
 * [-order, order], weighted by the product of the per-axis values produced by
 * compute_phi1(). Only lower-level points whose global index lies inside
 * [local_lo, local_hi] of level lev are written. Values are added to
 * grid->e[lev]; nothing is cleared here.
 *
 * grid: MSM grid hierarchy; levels lev and lev + 1 are accessed.
 * lev:  index of the lower (target) level.
 */
void prolong(msm_grid_t *grid, int lev) {
  int order = grid->order;
  vec<int> qout_lo = {grid->xpart.qout_lo[lev + 1], grid->ypart.qout_lo[lev + 1], grid->zpart.qout_lo[lev + 1]};
  vec<int> qout_hi = {grid->xpart.qout_hi[lev + 1], grid->ypart.qout_hi[lev + 1], grid->zpart.qout_hi[lev + 1]};
  // Per-axis index ratio between the two levels: 1 when both levels have the
  // same nmsm along that axis, 2 otherwise.
  vec<int> factor;
  factor.x = grid->xpart.nmsm[lev] == grid->xpart.nmsm[lev + 1] ? 1 : 2;
  factor.y = grid->ypart.nmsm[lev] == grid->ypart.nmsm[lev + 1] ? 1 : 2;
  factor.z = grid->zpart.nmsm[lev] == grid->zpart.nmsm[lev + 1] ? 1 : 2;

  // The loop below keeps 0 plus every odd value in [-order, order], i.e. at
  // most order + 2 entries, so both tables are sized for order == MAX_ORDER.
  enum { MAX_NU = MAX_ORDER + 2 };
  int nus[MAX_NU], nnu = 0;
  vec<real> phi1d[MAX_NU];
  for (int i = -order; i <= order; i++) {
    if ((i & 1) == 0 && i != 0)
      continue;
    assert(nnu < MAX_NU);
    nus[nnu] = i;
    vec<real> d = {i * 1.0 / factor.x, i * 1.0 / factor.y, i * 1.0 / factor.z};
    compute_phi1(phi1d + nnu, &d, order);
    nnu++;
  }
  real *eupper = grid->e[lev + 1];
  real *elower = grid->e[lev];
  for (int i = qout_lo.x; i <= qout_hi.x; i++) {
    // ii/jj/kk index the upper-level array, cx/cy/cz the lower-level array.
    int ii = i - grid->low[lev + 1].x;
    int cx = i * factor.x - grid->low[lev].x;
    for (int j = qout_lo.y; j <= qout_hi.y; j++) {
      int jj = j - grid->low[lev + 1].y;
      int cy = j * factor.y - grid->low[lev].y;
      for (int k = qout_lo.z; k <= qout_hi.z; k++) {
        int kk = k - grid->low[lev + 1].z;
        int cz = k * factor.z - grid->low[lev].z;
        real etmp = eupper[(ii * grid->nall[lev + 1].y + jj) * grid->nall[lev + 1].z + kk];
        for (int dx = 0; dx < nnu; dx++) {
          int x = cx + nus[dx];
          // Skip targets outside the locally owned range of the lower level.
          if (x + grid->low[lev].x < grid->xpart.local_lo[lev] || x + grid->low[lev].x > grid->xpart.local_hi[lev])
            continue;
          real phix = phi1d[dx].x;
          for (int dy = 0; dy < nnu; dy++) {
            int y = cy + nus[dy];
            if (y + grid->low[lev].y < grid->ypart.local_lo[lev] || y + grid->low[lev].y > grid->ypart.local_hi[lev])
              continue;
            real phixy = phix * phi1d[dy].y;
            for (int dz = 0; dz < nnu; dz++) {
              int z = cz + nus[dz];
              if (z + grid->low[lev].z < grid->zpart.local_lo[lev] || z + grid->low[lev].z > grid->zpart.local_hi[lev])
                continue;
              real phixyz = phixy * phi1d[dz].z;
              elower[(x * grid->nall[lev].y + y) * grid->nall[lev].z + z] += etmp * phixyz;
            }
          }
        }
      }
    }
  }
}

/*
 * Spread atom charges onto the level-0 charge grid.
 *
 * Clears mgrid->q[0], then for every atom of every cell visited by
 * FOREACH_LOCAL_CELL adds q_i * phi_x * phi_y * phi_z to each grid point of
 * the atom's stencil, with the 1-D weights taken from compute_phis(). The sum
 * of squared charges over the visited atoms is stored in mgrid->qsqsum.
 */
void cell2qgrid(cellgrid_t *cgrid, msm_grid_t *mgrid) {
  int order = mgrid->order;
  // Stencil offsets run from first_off to last_off inclusive.
  int first_off = -(order - 1) / 2;
  int last_off = order / 2;
  int npts = last_off - first_off + 1;

  vec<real> box_lo = vecarr(mgrid->gbox.lo);
  vec<real> delinv = {mgrid->xpart.delinv[0], mgrid->ypart.delinv[0], mgrid->zpart.delinv[0]};
  vec<int> low = vecarr(mgrid->low[0]);
  vec<int> nall = vecarr(mgrid->nall[0]);

  real *qgrid = mgrid->q[0];
  memset(qgrid, 0, sizeof(real) * vecvol(nall));

  double charge_sq = 0;
  FOREACH_LOCAL_CELL(cgrid, ii, jj, kk, cell) {
    for (int ia = 0; ia < cell->natom; ia++) {
      // Atom position is cell->basis + cell->x[ia].
      vec<real> pos;
      vecaddv(pos, cell->basis, cell->x[ia]);
      real qi = cell->q[ia];
      charge_sq += qi * qi;

      // Convert to grid units measured from the global box origin.
      vec<real> gpos;
      vecsubv(gpos, pos, box_lo);
      vecmulv(gpos, gpos, delinv);
      vec<int> gidx;
      vecfloor(gidx, gpos);
      // Offset of the floor grid point from the atom (gidx - gpos).
      vec<real> frac;
      vecsubv(frac, gidx, gpos);
      // Index of the floor grid point inside the local array.
      vec<int> base;
      vecsubv(base, gidx, low);

      vec<real> phis[MAX_ORDER];
      compute_phis(phis, &frac, order);

      for (int sx = 0; sx < npts; sx++) {
        int gx = base.x + (first_off + sx);
        real wx = qi * phis[sx].x;
        for (int sy = 0; sy < npts; sy++) {
          int gy = base.y + (first_off + sy);
          real wxy = wx * phis[sy].y;
          for (int sz = 0; sz < npts; sz++) {
            int gz = base.z + (first_off + sz);
            assert(gx >= 0 && gx < nall.x && gy >= 0 && gy < nall.y && gz >= 0 && gz < nall.z);
            real wxyz = wxy * phis[sz].z;
            qgrid[(gx * nall.y + gy) * nall.z + gz] += wxyz;
          }
        }
      }
    }
  }
  mgrid->qsqsum = charge_sq;
}

/*
 * Interpolate the level-0 potential grid (mgrid->e[0]) back to the atoms of
 * every cell visited by FOREACH_LOCAL_CELL and store the result in cell->f.
 *
 * For each atom the three components are built from the stencil sums of
 * dphi * phi * phi * e (derivative weight on the matching axis), scaled
 * per axis by delinv and by coul_const * q_i.
 *
 * NOTE(review): this routine resets cell->f[i] to zero before adding its
 * contribution, whereas elocalgrid2cell() accumulates into the existing
 * value. Confirm that callers expect the reset on this path.
 */
void egrid2cell(cellgrid_t *cgrid, msm_grid_t *mgrid) {
  int order = mgrid->order;
  vec<real> delinv = {mgrid->xpart.delinv[0], mgrid->ypart.delinv[0], mgrid->zpart.delinv[0]};
  vec<real> box_lo = vecarr(mgrid->gbox.lo);

  vec<int> low = vecarr(mgrid->low[0]);

  vec<int> nall = vecarr(mgrid->nall[0]);
  real *egrid = mgrid->e[0];
  // Stencil offsets run from nlower to nupper inclusive.
  int nlower = -(order - 1) / 2;
  int nupper = order / 2;

  FOREACH_LOCAL_CELL(cgrid, ii, jj, kk, cell) {
    for (int i = 0; i < cell->natom; i++) {
      cell->f[i].x = cell->f[i].y = cell->f[i].z = 0;
      vec<real> xi;
      vecaddv(xi, cell->basis, cell->x[i]);
      real qi = cell->q[i];
      // Position in grid units relative to the global box origin.
      vec<real> x_rel;
      vecsubv(x_rel, xi, box_lo);
      vecmulv(x_rel, x_rel, delinv);
      vec<int> ix_rel;
      vecfloor(ix_rel, x_rel);
      // Offset of the floor grid point from the atom (ix_rel - x_rel).
      vec<real> frac;
      vecsubv(frac, ix_rel, x_rel);
      // Index of the floor grid point inside the local array.
      vec<int> offset;
      vecsubv(offset, ix_rel, low);
      vec<real> phis[MAX_ORDER], dphis[MAX_ORDER];
      compute_phis(phis, &frac, order);
      compute_dphis(dphis, &frac, order);
      vec<real> e = {0, 0, 0};
      for (int dx = nlower; dx <= nupper; dx++) {
        int cur_x = offset.x + dx;
        real phix = phis[dx - nlower].x;
        real dphix = dphis[dx - nlower].x;
        for (int dy = nlower; dy <= nupper; dy++) {
          int cur_y = offset.y + dy;
          real phiy = phis[dy - nlower].y;
          real dphiy = dphis[dy - nlower].y;
          for (int dz = nlower; dz <= nupper; dz++) {
            int cur_z = offset.z + dz;
            real phiz = phis[dz - nlower].z;
            real dphiz = dphis[dz - nlower].z;
            assert(cur_x >= 0 && cur_x < nall.x && cur_y >= 0 && cur_y < nall.y && cur_z >= 0 && cur_z < nall.z);
            real etmp = egrid[(cur_x * nall.y + cur_y) * nall.z + cur_z];
            e.x += dphix * phiy * phiz * etmp;
            e.y += phix * dphiy * phiz * etmp;
            e.z += phix * phiy * dphiz * etmp;
          }
        }
      }
      // The derivative weights are in grid units; delinv converts per axis.
      vecmulv(e, e, delinv);
      real qfactor = mgrid->coul_const * qi;
      vecscaleaddv(cell->f[i], cell->f[i], e, 1, qfactor);
    }
  }
}

/*
 * Interpolate the level-0 potential grid (mgrid->e[0]) onto the atoms of
 * every cell visited by FOREACH_CELL, using only grid points whose global
 * index lies inside [local_lo[0], local_hi[0]] on all three axes.
 *
 * The contribution is added to the existing cell->f[i]; nothing is reset.
 * Each component is a stencil sum of dphi * phi * phi * e, scaled per axis
 * by delinv and by coul_const * q_i.
 */
void elocalgrid2cell(cellgrid_t *cgrid, msm_grid_t *mgrid) {
  int order = mgrid->order;
  // Stencil offsets run from first_off to last_off inclusive.
  int first_off = -(order - 1) / 2;
  int last_off = order / 2;

  vec<real> box_lo = vecarr(mgrid->gbox.lo);
  vec<real> delinv = {mgrid->xpart.delinv[0], mgrid->ypart.delinv[0], mgrid->zpart.delinv[0]};
  vec<int> low = vecarr(mgrid->low[0]);
  vec<int> nall = vecarr(mgrid->nall[0]);
  real *egrid = mgrid->e[0];

  FOREACH_CELL(cgrid, ii, jj, kk, cell) {
    for (int ia = 0; ia < cell->natom; ia++) {
      vec<real> pos;
      vecaddv(pos, cell->basis, cell->x[ia]);
      real qi = cell->q[ia];

      // Position in grid units relative to the global box origin.
      vec<real> gpos;
      vecsubv(gpos, pos, box_lo);
      vecmulv(gpos, gpos, delinv);
      vec<int> gidx;
      vecfloor(gidx, gpos);
      // Offset of the floor grid point from the atom (gidx - gpos).
      vec<real> frac;
      vecsubv(frac, gidx, gpos);
      // Index of the floor grid point inside the local array.
      vec<int> base;
      vecsubv(base, gidx, low);

      vec<real> phis[MAX_ORDER], dphis[MAX_ORDER];
      compute_phis(phis, &frac, order);
      compute_dphis(dphis, &frac, order);

      vec<real> acc = {0, 0, 0};
      for (int ox = first_off; ox <= last_off; ox++) {
        int global_x = gidx.x + ox;
        if (!(global_x >= mgrid->xpart.local_lo[0] && global_x <= mgrid->xpart.local_hi[0]))
          continue;
        int gx = base.x + ox;
        real phix = phis[ox - first_off].x;
        real dphix = dphis[ox - first_off].x;
        for (int oy = first_off; oy <= last_off; oy++) {
          int global_y = gidx.y + oy;
          if (!(global_y >= mgrid->ypart.local_lo[0] && global_y <= mgrid->ypart.local_hi[0]))
            continue;
          int gy = base.y + oy;
          real phiy = phis[oy - first_off].y;
          real dphiy = dphis[oy - first_off].y;
          for (int oz = first_off; oz <= last_off; oz++) {
            int global_z = gidx.z + oz;
            if (!(global_z >= mgrid->zpart.local_lo[0] && global_z <= mgrid->zpart.local_hi[0]))
              continue;
            int gz = base.z + oz;
            real phiz = phis[oz - first_off].z;
            real dphiz = dphis[oz - first_off].z;
            assert(gx >= 0 && gx < nall.x && gy >= 0 && gy < nall.y && gz >= 0 && gz < nall.z);
            real etmp = egrid[(gx * nall.y + gy) * nall.z + gz];
            acc.x += dphix * phiy * phiz * etmp;
            acc.y += phix * dphiy * phiz * etmp;
            acc.z += phix * phiy * dphiz * etmp;
          }
        }
      }
      vecmulv(acc, acc, delinv);
      real qfactor = mgrid->coul_const * qi;
      vecscaleaddv(cell->f[ia], cell->f[ia], acc, 1, qfactor);
    }
  }
}
/*
 * Run one full MSM pass and add its energy/virial to *stat.
 *
 * Sequence: spread charges to level 0 and reverse_q() them; for each level
 * (restricting from the level below and reverse_q() first on levels above 0)
 * forward_q() the charges and run direct(); then walk back down calling
 * forward_e() on the upper level and prolong() onto the lower one; finally
 * interpolate the level-0 potential back to the atoms.
 *
 * param is not used by this function.
 */
void msm_compute(cellgrid_t *cgrid, msm_grid_t *mgrid, void *param, mdstat_t *stat) {
  mpp_t *mpp = mgrid->mpp;
  mdstat_t grid_stat;
  memset(&grid_stat, 0, sizeof(mdstat_t));

  cell2qgrid(cgrid, mgrid);
  reverse_q(mgrid, mpp, mgrid->q[0], 0);

  // Upward sweep over the levels.
  for (int lev = 0; lev < mgrid->maxlev; lev++) {
    if (lev >= 1) {
      restriction(mgrid, lev - 1);
      reverse_q(mgrid, mpp, mgrid->q[lev], lev);
    }
    forward_q(mgrid, mpp, mgrid->q[lev], lev);
    direct(mgrid, lev, &grid_stat);
  }

  // Downward sweep: prolong each upper level onto the one below it.
  for (int upper = mgrid->maxlev - 1; upper >= 1; upper--) {
    forward_e(mgrid, mpp, mgrid->e[upper], upper);
    prolong(mgrid, upper - 1);
  }

  // Physical extent of the upper half of the stencil on each axis.
  vec<real> delinv = {mgrid->xpart.delinv[0], mgrid->ypart.delinv[0], mgrid->zpart.delinv[0]};
  int nupper = mgrid->order / 2;
  vec<real> lhalo, dist;
  vecset1(dist, nupper);
  vecdivv(lhalo, dist, delinv);

  // If the halo is small enough we can reverse f, else we must forward e.
  bool halo_within_rcut = lhalo.x < mgrid->rcut && lhalo.y < mgrid->rcut && lhalo.z < mgrid->rcut;
  if (halo_within_rcut) {
    elocalgrid2cell(cgrid, mgrid);
  } else {
    forward_e(mgrid, mpp, mgrid->e[0], 0);
    egrid2cell(cgrid, mgrid);
  }

  // Subtract the qsqsum-based self term before scaling the grid energy.
  double e_self = mgrid->qsqsum * msm_gamma(0, mgrid->order) / mgrid->rcut;
  stat->ecoul += (grid_stat.ecoul - e_self) * 0.5 * mgrid->coul_const;
  for (int k = 0; k < 6; k++) {
    stat->virial[k] += 0.5 * mgrid->coul_const * grid_stat.virial[k];
  }
}

#ifdef MSM_TEST
/*
 * Test-only charge spreading: clear grid->q[0] and add the charge of every
 * particle that passes vec_in_box(x[i], grid->lbox) to the grid points of its
 * stencil, weighted by the product of the 1-D values from compute_phis().
 *
 * x, q: particle positions and charges, n entries each.
 */
void particle2grid(msm_grid_t *grid, vec<real> * x, real *q, int n) {
  int order = grid->order;
  // Stencil offsets run from first_off to last_off inclusive.
  int first_off = -(order - 1) / 2;
  int last_off = order / 2;
  int npts = last_off - first_off + 1;

  vec<real> box_lo = vecarr(grid->gbox.lo);
  vec<real> delinv = {grid->xpart.delinv[0], grid->ypart.delinv[0], grid->zpart.delinv[0]};
  vec<int> low = vecarr(grid->low[0]);
  vec<int> nall = vecarr(grid->nall[0]);

  real *qgrid = grid->q[0];
  memset(qgrid, 0, sizeof(real) * vecvol(nall));

  for (int ip = 0; ip < n; ip++) {
    if (!vec_in_box(x[ip], grid->lbox))
      continue;

    // Position in grid units relative to the global box origin.
    vec<real> gpos;
    vecsubv(gpos, x[ip], box_lo);
    vecmulv(gpos, gpos, delinv);
    vec<int> gidx;
    vecfloor(gidx, gpos);
    // Offset of the floor grid point from the particle (gidx - gpos).
    vec<real> frac;
    vecsubv(frac, gidx, gpos);
    // Index of the floor grid point inside the local array.
    vec<int> base;
    vecsubv(base, gidx, low);

    vec<real> phis[MAX_ORDER];
    compute_phis(phis, &frac, order);

    real charge = q[ip];
    for (int sx = 0; sx < npts; sx++) {
      int gx = base.x + (first_off + sx);
      real wx = charge * phis[sx].x;
      for (int sy = 0; sy < npts; sy++) {
        int gy = base.y + (first_off + sy);
        real wxy = wx * phis[sy].y;
        for (int sz = 0; sz < npts; sz++) {
          int gz = base.z + (first_off + sz);
          assert(gx >= 0 && gx < nall.x && gy >= 0 && gy < nall.y && gz >= 0 && gz < nall.z);
          real wxyz = wxy * phis[sz].z;
          qgrid[(gx * nall.y + gy) * nall.z + gz] += wxyz;
        }
      }
    }
  }
}

/*
 * Test-only interpolation: for every particle that passes
 * vec_in_box(x[i], grid->lbox), sum dphi * phi * phi * e over its stencil on
 * grid->e[0], scale per axis by delinv and by coul_const * q[i], and add the
 * result to f[i]. Particles outside the box are left untouched.
 *
 * f, x, q: per-particle output, positions and charges, n entries each.
 */
void grid2particle(msm_grid_t *grid, vec<real> * f, vec<real> * x, real *q, int n, real coul_const) {
  int order = grid->order;
  // Stencil offsets run from first_off to last_off inclusive.
  int first_off = -(order - 1) / 2;
  int last_off = order / 2;
  int npts = last_off - first_off + 1;

  vec<real> box_lo = vecarr(grid->gbox.lo);
  vec<real> delinv = {grid->xpart.delinv[0], grid->ypart.delinv[0], grid->zpart.delinv[0]};
  vec<int> low = vecarr(grid->low[0]);
  vec<int> nall = vecarr(grid->nall[0]);
  real *egrid = grid->e[0];

  for (int ip = 0; ip < n; ip++) {
    if (!vec_in_box(x[ip], grid->lbox))
      continue;

    // Position in grid units relative to the global box origin.
    vec<real> gpos;
    vecsubv(gpos, x[ip], box_lo);
    vecmulv(gpos, gpos, delinv);
    vec<int> gidx;
    vecfloor(gidx, gpos);
    // Offset of the floor grid point from the particle (gidx - gpos).
    vec<real> frac;
    vecsubv(frac, gidx, gpos);
    // Index of the floor grid point inside the local array.
    vec<int> base;
    vecsubv(base, gidx, low);

    vec<real> phis[MAX_ORDER], dphis[MAX_ORDER];
    compute_phis(phis, &frac, order);
    compute_dphis(dphis, &frac, order);

    vec<real> acc = {0, 0, 0};
    for (int sx = 0; sx < npts; sx++) {
      int gx = base.x + (first_off + sx);
      real phix = phis[sx].x;
      real dphix = dphis[sx].x;
      for (int sy = 0; sy < npts; sy++) {
        int gy = base.y + (first_off + sy);
        real phiy = phis[sy].y;
        real dphiy = dphis[sy].y;
        for (int sz = 0; sz < npts; sz++) {
          int gz = base.z + (first_off + sz);
          real phiz = phis[sz].z;
          real dphiz = dphis[sz].z;
          assert(gx >= 0 && gx < nall.x && gy >= 0 && gy < nall.y && gz >= 0 && gz < nall.z);
          real etmp = egrid[(gx * nall.y + gy) * nall.z + gz];
          acc.x += dphix * phiy * phiz * etmp;
          acc.y += phix * dphiy * phiz * etmp;
          acc.z += phix * phiy * dphiz * etmp;
        }
      }
    }
    vecmulv(acc, acc, delinv);
    real qfactor = coul_const * q[ip];
    vecscaleaddv(f[ip], f[ip], acc, 1, qfactor);
  }
}
#define MAXN 10
#define N 1
int main(int argc, char **argv) {
  memory_init();
  mdstat_t stat = {0, 0, 0, 0, 0, 0, 0, 0};
  vec<double> *xbin = malloc(sizeof(double) * 7051 * 3);
  double *qbin = malloc(sizeof(double) * 7051);
  FILE *fx = fopen("x.bin", "rb");
  fread(xbin, sizeof(double), 7051 * 3, fx);
  FILE *fq = fopen("q.bin", "rb");
  fread(qbin, sizeof(double), 7051, fq);
  fclose(fx);
  fclose(fq);

  vec<real> *x = malloc(sizeof(double) * 3 * 7051 * MAXN);
  real *q = malloc(sizeof(double) * 7051 * MAXN);
  for (int i = 0; i < 7051; i++) {
    veccpy(x[i], xbin[i]);
    q[i] = qbin[i];
  }
  for (int n = 1; n < N; n++) {
    for (int i = 0; i < 7051; i++) {
      veccpy(x[i + 7051 * n], xbin[i]);
      x[i + 7051 * n].z += 23.5 * 2;
      q[i + 7051 * n] = qbin[i];
    }
  }
  vec<int> nlev = {4, 4, 4};
  box<real> gbox;
  gbox.lo.x = -21;
  gbox.lo.y = -22;
  gbox.lo.z = -23.5;
  gbox.hi.x = 21;
  gbox.hi.y = 22;
  gbox.hi.z = 23.5 * (2 * N - 1);
  vec<real> skin = {2, 2, 2};
  MPI_Init(&argc, &argv);
  mpp_t mpp;
  comm_init(&mpp, MPI_COMM_WORLD, &gbox);
  cellgrid_t cgrid;
  bond_graph_t graph;
  impr_index_t imidx;
  build_cells(&cgrid, 2, 12, 2, &mpp.lbox, 0, NULL, NULL, NULL, &graph, &imidx, NULL, NULL);
  comm_init_buf(&mpp, &cgrid);
  msm_grid_t grid;
  msm_grid_init(&grid, &mpp, &nlev, 4, 12, &cgrid.skin);
  particle2grid(&grid, x, q, 7051 * N);

  reverse_q(&grid, &mpp, grid.q[0], 0);

  char fn[30];
  FILE *qgrid;
  sprintf(fn, "qgrid%d_rev%d.bin", 0, mpp.pid);
  qgrid = fopen(fn, "wb");
  fwrite(grid.q[0], sizeof(real), vecvol(grid.nall[0]), qgrid);
  fclose(qgrid);

  for (int i = 0; i < grid.maxlev; i++) {
    if (i > 0) {
      restriction(&grid, i - 1);

      reverse_q(&grid, &mpp, grid.q[i], i);
      printf("%d: %d %d %d\n", i, grid.nall[i].x, grid.nall[i].y, grid.nall[i].z);
      printf("%d: %d %d %d\n", i, grid.low[i].x, grid.low[i].y, grid.low[i].z);
      sprintf(fn, "qgrid%d_rev%d.bin", i, mpp.pid);
      FILE *qgrid;
      qgrid = fopen(fn, "wb");
      fwrite(grid.q[i], sizeof(real), vecvol(grid.nall[i]), qgrid);
      fclose(qgrid);
      // return;
    }
    forward_q(&grid, &mpp, grid.q[i], i);
    sprintf(fn, "qgrid%d_fwd%d.bin", i, mpp.pid);
    qgrid = fopen(fn, "wb");
    fwrite(grid.q[i], sizeof(real), vecvol(grid.nall[i]), qgrid);
    fclose(qgrid);
    direct(&grid, i, &stat);
    //MPI_Finalize();
    //return 0;
  }
  for (int i = grid.maxlev - 2; i >= 0; i--) {
    puts("forward e");
    forward_e(&grid, &mpp, grid.e[i + 1], i + 1);
    prolong(&grid, i);
    sprintf(fn, "egrid%d_%d.bin", i, mpp.pid);
    FILE *egrid = fopen(fn, "wb");
    fwrite(grid.e[i], sizeof(real), vecvol(grid.nall[i]), egrid);
    fclose(egrid);
  }
  forward_e(&grid, &mpp, grid.e[0], 0);
  sprintf(fn, "egrid%d_%d.bin", 0, mpp.pid);
  FILE *egrid = fopen(fn, "wb");
  fwrite(grid.e[0], sizeof(real), vecvol(grid.nall[0]), egrid);
  fclose(egrid);
  vec<real> *f = malloc(sizeof(real) * 3 * 7051 * MAXN);
  memset(f, 0, sizeof(real) * 3 * 7051 * N);
  vec<real> *fsum = malloc(sizeof(real) * 3 * 7051 * MAXN);

  grid2particle(&grid, f, x, q, 7051 * N, 332.0636);
  MPI_Allreduce(f, fsum, 7051 * N * 3, mpi_real, MPI_SUM, MPI_COMM_WORLD);
  if (mpp.pid == 0) {
    FILE *ff = fopen("f.bin", "wb");
    fwrite(fsum, sizeof(real), 3 * 7051 * N, ff);
  }
  printf("%g %g %g %g %g %g %g\n", stat.ecoul, stat.virial[0], stat.virial[1], stat.virial[2], stat.virial[3], stat.virial[4], stat.virial[5]);
  MPI_Finalize();
}
#endif
